
from yacs.config import CfgNode as CN
from .model_path import MODEL_PATH

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

# Model: device, backbone, pretrained weights and embedding head.
_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.BACKBONE = CN({'NAME': "bninception"})
_C.MODEL.PRETRAIN = 'imagenet'
# NOTE: the key is spelled PRETRIANED_PATH (sic). It is kept verbatim because
# other code / YAML files may look it up by this exact name.
_C.MODEL.PRETRIANED_PATH = MODEL_PATH
_C.MODEL.HEAD = CN({
    'NAME0': "linear_norm",
    'DIM': 128,
    # options for the "trans" branch of the head
    'trans_in_channels': [64, 128, 320, 512],
    'trans_in_index': [0, 1, 2, 3],
    'trans_channels': 128,
    'trans_embdding_dim': 1024,  # (sic) key name kept verbatim for compatibility
    'NAME3': "linear_norm_map",
    'num_clusters': [1, 1, 1, 64],
    'feature_dim': 512,  # 64,128,320,512
    'vladv2': False,
    'normalize_input': True,
    'use_faiss': True,
    'patch_sizes': '4',
    'strides': '1',
})
_C.MODEL.WEIGHT = ""

# Checkpoint save dir
_C.SAVE_DIR = 'output'

# Loss selection: the primary loss (NAME), the "trans" loss (TRANS_NAME) and
# an optional auxiliary loss (NAME_AUX) weighted by AUX_WEIGHT.
_C.LOSSES = CN({
    'NAME': 'ms_loss',
    'TRANS_NAME': 'ms_trans_loss',
    'NAME_AUX': '',  # NOTE(review): empty presumably means "no auxiliary loss" — confirm
    'AUX_WEIGHT': 0.01,
})

# proxy_anchor_loss
_C.LOSSES.PROXY_ANCHOR_LOSS = CN({
    'nb_classes': 100,
    'sz_embed': 512,
})

# cbml_loss
_C.LOSSES.CBML_LOSS = CN({
    'POS_A': 0.5,
    'POS_B': 0.5,
    'NEG_A': 1.0,
    'NEG_B': 0.01,
    'MARGIN': 0.1,
    'WEIGHT': 1.0,
    'HYPER_WEIGHT': 0.1,
    'WEIGHT_P': 0.25,
    'WEIGHT_N': 0.25,
    'ADAPTIVE_NEG': True,
    'TYPE': 'log',  # one of: log, sqrt, constant
})

# softtriple_loss
_C.LOSSES.SOFTTRIPLE_LOSS = CN({
    'LA': 20,
    'GAMMA': 0.1,
    'TAU': 0.2,
    'MARGIN': 0.01,
    'K': 2,
    'CLUSTERS': 100,  # cars 98, cub 100, sop 11318, clothes 3997
})

# rank_loss (ranked-list loss); the *_GNN variants are a second margin/alpha pair.
_C.LOSSES.RANKED_LIST_LOSS = CN({
    'MARGIN': 0.4,
    'ALPHA': 1.4,
    'MARGIN_GNN': 0.4,
    'ALPHA_GNN': 1.4,
    'TVAL': 10,
})

# ms_loss (multi-similarity loss)
_C.LOSSES.MULTI_SIMILARITY_LOSS = CN({
    'SCALE_POS': 2.0,
    'SCALE_NEG': 40.0,
    'HARD_MINING': True,
})

# ms_trans_loss
_C.LOSSES.MULTI_SIMILARITY_TRANS_LOSS = CN({
    'sigma': 1,
    'delta': 1,
})

# margin_loss
# (BETA_CONSTANT was previously assigned twice with the same value; the
# redundant second assignment has been removed — the resulting node is identical.)
_C.LOSSES.MARGIN_LOSS = CN()
_C.LOSSES.MARGIN_LOSS.BETA_CONSTANT = False
_C.LOSSES.MARGIN_LOSS.N_CLASSES = 100  # cars 98, cub 100, sop 11318, clothes 3997
_C.LOSSES.MARGIN_LOSS.CUTOFF = 0.5
_C.LOSSES.MARGIN_LOSS.UPPER_CUTOFF = 1.4

# adv_loss
_C.LOSSES.ADV_LOSS = CN({
    'CLASS_DIM': 256,
    'AUX_DIM': 256,
    'PROJ_DIM': 256,
})

# proxynca_loss
_C.LOSSES.PROXY_LOSS = CN({
    'NB_CLASSES': 100,  # cars 98, cub 100, sop 11318, clothes 3997
    'SMOOTHING_CONST': 0.1,
    'SCALING_X': 3,  # SOP 1
    'SCALING_P': 3,  # SOP 8
})

# npair_loss
_C.LOSSES.NPAIR_LOSS = CN({'L2_REG': 0.02})

# angular_loss
_C.LOSSES.ANGULAR_LOSS = CN({
    'L2_REG': 0.02,
    'ANGLE_BOUND': 1.0,
    'LAMBDA_AND': 2,
})

# contrastive_loss, online_contrastive_loss
_C.LOSSES.CONTRASTIVE_LOSS = CN({
    'MARGIN': 1.0,
    'PAIR_SELECTOR': 'all',  # all, hard
})

# triplet_loss, online_triplet_loss
_C.LOSSES.TRIPLET_LOSS = CN({
    'MARGIN': 1.0,
    'TRIPLET_SELECTOR': 'all',  # all, hard
    'NEGATIVE_SELECTION_FN': 'hardest',  # hardest, random, semihard
})

# cluster_loss, cluster_loss_local
_C.LOSSES.CLUSTER_LOSS = CN({
    'MARGIN': 10.0,
    'USE_GPU': True,
    'ORDERED': True,
})

# histogram_loss
_C.LOSSES.HISTOGRAM_LOSS = CN()
_C.LOSSES.HISTOGRAM_LOSS.NUM_STEPS = 1000
_C.LOSSES.HISTOGRAM_LOSS.CUDA = True
# NOTE: these two keys were historically registered on CLUSTER_LOSS instead
# of HISTOGRAM_LOSS (a copy-paste slip that left HISTOGRAM_LOSS empty). They
# are kept so existing code and YAML overrides that use the old location keep
# working; yacs rejects merges of unknown keys, so removing them would break
# those files.
_C.LOSSES.CLUSTER_LOSS.NUM_STEPS = 1000
_C.LOSSES.CLUSTER_LOSS.CUDA = True

# center_loss
_C.LOSSES.CENTER_LOSS = CN({
    'CLASS': 100,  # cars 98, cub 100, sop 11318, clothes 3997
    'USE_GPU': True,
})

# crossentropy_loss

# Data option: image-list sources, batch sizes and loader settings.
_C.DATA = CN({
    'TRAIN_IMG_SOURCE': 'resource/datasets/CUB_200_2011/train.txt',
    'TEST_IMG_SOURCE': 'resource/datasets/CUB_200_2011/test.txt',
    'TRAIN_BATCHSIZE': 60,
    'TEST_BATCHSIZE': 128,
    'NUM_WORKERS': 8,
    'NUM_INSTANCES': 5,
    'KNN': 1,
    'MEMORY': 5,
})

# Input option
_C.INPUT = CN({
    # channel order and per-channel normalization (values given as fractions of 255)
    'MODE': 'BGR',
    'PIXEL_MEAN': [104. / 255, 117. / 255, 128. / 255],
    'PIXEL_STD': 3 * [1. / 255],
    # flip / resize / crop settings
    'FLIP_PROB': 0.5,
    'ORIGIN_SIZE': 256,
    'CROP_SCALE': [0.16, 1],
    'CROP_SIZE': 227,
})

# SOLVER: optimizer, learning-rate schedule, warmup and checkpoint period.
# NOTE: IS_FINETURN, FINETURN_MODE_PATH and accumation_steps are misspelled
# (sic) but kept verbatim because other code / YAML may look them up by name.
_C.SOLVER = CN({
    'IS_FINETURN': False,
    'FINETURN_MODE_PATH': '',
    'MAX_ITERS': 4000,
    'STEPS': [1000, 2000, 3000],
    'OPTIMIZER_NAME': 'SGD',
    'BASE_LR': 0.01,
    'BIAS_LR_FACTOR': 1,
    'WEIGHT_DECAY': 0.0005,
    'WEIGHT_DECAY_BIAS': 0.0005,
    'MOMENTUM': 0.9,
    'GAMMA': 0.1,
    'WARMUP_FACTOR': 0.01,
    'WARMUP_ITERS': 200,
    'WARMUP_METHOD': 'linear',
    'CHECKPOINT_PERIOD': 200,
    'RNG_SEED': 1,
    'accumation_steps': 2,
})

# Logger
_C.LOGGER = CN({
    'LEVEL': 20,  # numeric level; 20 == logging.INFO in the stdlib logging module
    'STREAM': 'stdout',
})

# Validation
_C.VALIDATION = CN({
    'VERBOSE': 200,  # NOTE(review): presumably the validation interval in iterations — confirm
    'IS_VALIDATION': True,
})
